import gym
import highway_env
import numpy as np
from stable_baselines3 import A2C
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.callbacks import BaseCallback
from torch.utils.tensorboard import SummaryWriter
# Gym id of the environment to train on. The environment itself is built
# further down, next to the model that uses it, so it is not constructed here
# as well (a second instance would only be shadowed and never closed).
env_name = "roundabout-v0"


class RewardCallback(BaseCallback):
    """Record the reward observed at every training step.

    ``ep_rewards`` receives one entry per callback step: the mean, over the
    rewards returned by the environment step that was just taken, as exposed
    by the training loop through ``self.locals["rewards"]``. Despite its
    name, the list therefore holds per-step values, not per-episode returns.
    Each value is also written to a TensorBoard writer when one is available.

    Args:
        verbose: Verbosity level forwarded to ``BaseCallback``.
        summary_writer: Optional TensorBoard writer to log to and to close
            when training ends. When omitted, the module-level ``writer`` is
            used if it exists, which is the object this callback closed
            before the argument was introduced.
    """

    def __init__(self, verbose=0, summary_writer=None):
        super().__init__(verbose)
        self.ep_rewards = []
        self._summary_writer = summary_writer

    def _get_writer(self):
        """Return the writer to use, or ``None`` when there is none."""
        if self._summary_writer is not None:
            return self._summary_writer
        # Resolved lazily because the script assigns ``writer`` only after
        # this class has been defined; ``None`` avoids a NameError when the
        # class is used without that global.
        return globals().get("writer")

    def _on_step(self) -> bool:
        """Store and log the mean reward of the current step.

        Returns:
            Always ``True``, so training is never interrupted by this hook.
        """
        # NOTE(review): the rollout buffer is deliberately not averaged here;
        # in Stable-Baselines3 it is filled after this hook runs, so its mean
        # lags by a step and is diluted by not-yet-written slots.
        step_rewards = self.locals.get("rewards")
        if step_rewards is None:
            return True

        mean_reward = float(np.mean(step_rewards))
        self.ep_rewards.append(mean_reward)

        tb_writer = self._get_writer()
        if tb_writer is not None:
            tb_writer.add_scalar(
                "train/step_reward", mean_reward, self.num_timesteps
            )
        return True

    def _on_training_end(self) -> None:
        """Close the TensorBoard writer, if there is one."""
        tb_writer = self._get_writer()
        if tb_writer is not None:
            tb_writer.close()


# Create the TensorBoard writer; RewardCallback closes it when training ends.
writer = SummaryWriter()

# Create the environment. The reset()/step() calls below use the Gym >= 0.26
# API, where the render mode is chosen at construction time instead of being
# passed to render(); "rgb_array" makes env.render() return the frame.
env = gym.make(env_name, render_mode="rgb_array")

model = A2C(
    policy="MlpPolicy",
    env=env,
    verbose=1,
    learning_rate=0.01,
    n_steps=10,
    gamma=0.95,
)

# Train the model. The callback has to be handed to learn() to take effect;
# model.tensorboard_log is a plain attribute, not something that can be
# called with the writer.
callback = RewardCallback()
model.learn(total_timesteps=int(1e2), callback=callback)

# Evaluate the trained model.
mean_reward, std_reward = evaluate_policy(model, env, n_eval_episodes=10)
print(f"mean_reward={mean_reward:.2f} +/- {std_reward:.2f}")

# Save the model.
model.save("roundabout_A2C")

# Load the model back.
model = A2C.load("roundabout_A2C")

# Roll out the loaded policy for 40 episodes, rendering every step.
for _ in range(40):
    done = truncated = False
    obs, info = env.reset()
    while not (done or truncated):
        action, _states = model.predict(obs, deterministic=True)
        obs, reward, done, truncated, info = env.step(action)

        # One render call per step; the returned frame is available here for
        # a video writer if the rollouts need to be recorded.
        cur_frame = env.render()

# Release the simulator and any viewer window.
env.close()